import numpy as np
from function import quadraticWithHVP, logistic
import copy

# Inexact damped newton method as in the DiSCO paper
# We solve the quadratic subproblem associated
# with hessian inversion using Local SGD
# F is the objective function's stochastic oracle
# X, Y are the data matrices
# T is the horizon for the newton algorithms
# alpha is a step-size parameter for the newton steps
# M, K, R are the local SGD parameters (see below)
# lr is the learning rate for local SGD
# mu is the regularization constant for LR


def inexact_newton(F, X, Y, T, alpha, M, K, R, lr, momentum=0, mu=1e-3, quadSolver="localSGD", damp=True, gap=5):
    """Inexact damped Newton method (DiSCO-style) with a distributed inner solver.

    At every outer step the Newton direction at the current iterate
    ``u = W[t]`` is approximated by running ``localSGD`` or ``fedAC`` in
    HVP mode (``forHVP=True``) and taking the last row of the returned
    history. The iterate then moves along that direction, optionally
    damped by the approximate Newton decrement.

    Args:
        F: objective oracle called as ``F(X, Y, w, mu, order=k)``.
            ``order=5`` returns the loss. ``order=3`` is called with an
            extra direction argument, and its second output is used as
            ``H(u) @ Delta_t`` (presumably a Hessian-vector product --
            confirm against the oracle's definition).
        X, Y: data matrices; ``X`` has shape ``(n, d)``.
        T: number of outer Newton steps.
        alpha: step-size factor for the Newton steps (the DiSCO analysis
            suggests a value in ``[0.5, 1]``).
        M, K, R: machines, local steps and communication rounds for the
            inner solver.
        lr: learning rate of the inner solver.
        momentum: heavy-ball momentum, only used by ``localSGD``.
        mu: L2 regularization constant.
        quadSolver: ``"localSGD"`` or ``"fedAC"``.
        damp: if True the step is ``alpha / (1 + sqrt(Delta^T H Delta))``,
            otherwise it is ``alpha``.
        gap: record the loss every ``gap`` outer steps.

    Returns:
        ``(W, losses)``. ``W`` is a ``(T + 1, d)`` array of iterates with
        ``W[0] = 0``. ``losses`` holds the initial loss followed by the loss
        after every ``gap``-th step.

    Raises:
        ValueError: if ``quadSolver`` is not a supported solver name.
    """
    # Validate up front: otherwise an unknown solver leaves Delta_t unbound.
    if quadSolver not in ("localSGD", "fedAC"):
        raise ValueError(
            f"Unknown quadSolver {quadSolver!r}; expected 'localSGD' or 'fedAC'")

    _, d = X.shape

    # matrix for storing all the newton weights
    W = np.zeros((T + 1, d))

    # storing the losses for each newton step
    losses = [F(X, Y, W[0].reshape((d, 1)), mu, order=5)]

    for t in range(T):
        u = np.asmatrix(W[t]).T
        if quadSolver == "localSGD":
            history = localSGD(F, X, Y, M, K, R, mu, u, lr, momentum, forHVP=True)
        else:
            # fedAC derives gamma / alpha / beta from lr itself (FedAc-I by
            # default), so only lr and the HVP anchor point are passed.
            history = fedAC(F, X, Y, M, K, R, mu, lr, u=u, forHVP=True)
        Delta_t = np.asmatrix(history[-1]).T

        # reverse scaling with Newton decrement
        # alpha should be in the range [0.5, 1]
        _, hvpF = F(X, Y, u, mu, Delta_t, order=3)
        if damp:
            eta_t = alpha / (1 + np.sqrt(Delta_t.T.dot(hvpF).item(0)))
        else:
            eta_t = alpha

        # updating with the obtained direction
        W[t + 1] = W[t] - eta_t * Delta_t.T

        if (t + 1) % gap == 0:
            # store the loss
            losses.append(F(X, Y, W[t + 1].reshape((d, 1)), mu, order=5))

    return W, losses

# local SGD algorithm which allows for using Hessian Vector products
# F is the objective function's stochastic oracle
# X, Y are the data matrices
# M (machines), K (local steps), R (communication rounds)
# are the local SGD parameters as usual
# u is the input for Q(v, u), i.e., the input for F
# lr is the constant learning rate for local SGD
# forHVP decides if a Hessian-vector-product oracle has to be used


def localSGD(F, X, Y, M, K, R, mu, u=None, lr=0.1, momentum=0, forHVP=True, gap=5):
    """Simulate Local SGD (optionally with heavy-ball momentum) on M machines.

    Each of the ``R`` communication rounds starts every machine from the
    latest synchronized iterate. Each machine then takes ``K`` local steps,
    and the local iterates are averaged at the end of the round. Momentum
    is restarted at the beginning of every round: the first local step is
    a plain gradient step.

    Args:
        F: objective oracle. ``F(..., order=1)`` returns ``(_, grad)``;
            ``F(..., order=5)`` returns the loss.
        X, Y: data matrices; ``X`` has shape ``(n, d)``.
        M, K, R: number of machines, local steps per round, and rounds.
        mu: L2 regularization constant forwarded to the oracle.
        u: anchor point of the quadratic ``Q(v, u)``; only used when
            ``forHVP`` is True.
        lr: constant learning rate.
        momentum: heavy-ball coefficient.
        forHVP: if True, gradients come from ``quadraticWithHVP`` (the
            Newton subproblem) and no losses are tracked.
        gap: when losses are tracked, record one every ``gap`` rounds.

    Returns:
        The ``(R + 1, d)`` array of synchronized iterates if ``forHVP``;
        otherwise ``(iterates, losses)``, where ``losses`` starts with the
        loss at the zero initial point.
    """
    _, dim = X.shape
    track_loss = not forHVP

    # per-machine iterates inside a round / synchronized iterates per round
    local_iterates = np.zeros((M, dim))
    history = np.zeros((R + 1, dim))

    if track_loss:
        losses = [F(X, Y, history[0].reshape((dim, 1)), mu, order=5)]

    def oracle_grad(point):
        # Gradient of the Newton quadratic, or of the objective itself.
        if forHVP:
            _, g = quadraticWithHVP(F, X, Y, u, mu, point, order=1)
        else:
            _, g = F(X, Y, point, mu, order=1)
        return g

    for rnd in range(R):
        for machine in range(M):
            # every machine restarts from the last synchronized point
            local_iterates[machine] = history[rnd]
            prev = None
            for step in range(K):
                g = oracle_grad(np.asmatrix(local_iterates[machine]).T)
                current = local_iterates[machine].copy()
                if step == 0:
                    local_iterates[machine] = current - lr * g.T
                else:
                    local_iterates[machine] = (
                        current - lr * g.T + momentum * (current - prev))
                prev = current

        # communication: average the local iterates
        history[rnd + 1] = np.mean(local_iterates, axis=0)

        if track_loss and (rnd + 1) % gap == 0:
            snapshot = history[rnd + 1].reshape((dim, 1)).copy()
            losses.append(F(X, Y, snapshot, mu, order=5))

    return history if forHVP else (history, losses)


# Federated Accelerated SGD from the paper by Yuan and Ma'20
# This is a parallel variant of Ghadimi and Lan's famous
# mini-max optimal algorithm
# F is the custom order stochastic oracle for the objective function
# X, Y are the data matrices
# M, K, R are the local SGD problem characterizers
# mu is the L2 regularization parameter
# lr is the step size eta; the hyper-parameters gamma, alpha and beta
# of the optimizer are derived from it
# ver selects FedAc-I (ver=1) or FedAc-II (ver=2)

def fedAC(F, X, Y, M, K, R, mu, lr, ver=1, u=None, forHVP=False, gap=5):
    """Run Federated Accelerated SGD (FedAc, Yuan & Ma 2020) on M machines.

    Every machine keeps a pair ``(w, w_ag)`` and takes ``K`` accelerated
    local steps:
    ``w_md = w / beta + (1 - 1/beta) * w_ag``,
    ``w_ag <- w_md - lr * g(w_md)`` and
    ``w <- (1 - 1/alpha) * w + w_md / alpha - gamma * g(w_md)``.
    Both ``w`` and ``w_ag`` are then averaged across machines.

    Args:
        F: objective oracle. ``F(..., order=1)`` returns ``(_, grad)``;
            ``F(..., order=5)`` returns the loss.
        X, Y: data matrices; ``X`` has shape ``(n, d)``.
        M, K, R: number of machines, local steps per round, and rounds.
        mu: L2 regularization constant.
        lr: step size ``eta``. ``gamma = max(sqrt(eta / (mu K)), eta)``;
            ``alpha`` and ``beta`` are derived from it.
        ver: 1 for FedAc-I, 2 for FedAc-II hyper-parameter coupling.
        u: anchor point of the Newton quadratic (only used with ``forHVP``).
        forHVP: if True, gradients come from ``quadraticWithHVP`` and no
            losses are tracked.
        gap: when losses are tracked, record one every ``gap`` rounds.

    Returns:
        ``W_avg``, the ``(R + 1, d)`` history of the synchronized ``w_ag``
        iterate (FedAc's output sequence), if ``forHVP``; otherwise
        ``(W_avg, losses)``. In both modes ``W_avg[-1]`` is the final
        solution, matching how ``localSGD`` reports its history.

    Raises:
        ValueError: if ``ver`` is not 1 or 2.
    """
    # Without this check an unsupported ver leaves alpha/beta unbound.
    if ver not in (1, 2):
        raise ValueError(f"ver must be 1 (FedAc-I) or 2 (FedAc-II), got {ver!r}")

    _, d = X.shape

    gamma = max(np.sqrt(lr/(mu*K)), lr)

    if ver == 1:
        alpha = 1/(gamma*mu)
        beta = alpha + 1
    else:
        alpha = 1.5/(gamma*mu) - 0.5
        beta = (2 * alpha**2 - 1) / (alpha - 1)

    # matrices for different types of iterates
    W = np.zeros((M, d))
    W_ag = np.zeros((M, d))
    W_md = np.zeros((M, d))
    V = np.zeros((M, d))
    V_ag = np.zeros((M, d))
    W_avg = np.zeros((R+1, d))

    if not forHVP:
        # loss at the initial point, then one every `gap` rounds
        losses = [F(X, Y, W_avg[0].reshape((d, 1)), mu, order=5)]

    for r in range(R):
        for m in range(M):
            for k in range(K):
                W_md[m] = W[m]/float(beta) + (1-1.0/beta)*W_ag[m]
                if not forHVP:
                    _, grad = F(X, Y, np.asmatrix(W_md[m]).T, mu, order=1)
                else:
                    _, grad = quadraticWithHVP(
                        F, X, Y, u, mu, np.asmatrix(W_md[m]).T, order=1)
                V_ag[m] = W_md[m] - lr*grad.T
                V[m] = (1-1/alpha)*W[m] + W_md[m]/alpha - gamma*grad.T
                W[m] = V[m]
                W_ag[m] = V_ag[m]
        # communication: synchronize both sequences across machines
        W = np.tile(np.mean(V, axis=0), (M, 1))
        W_ag = np.tile(np.mean(V_ag, axis=0), (M, 1))
        W_avg[r+1] = np.mean(W_ag, axis=0)

        if (not forHVP) and ((r+1) % gap == 0):
            losses.append(F(X, Y, W_avg[r+1].reshape((d, 1)), mu, order=5))

    if not forHVP:
        return W_avg, losses
    # Return the w_ag history (not the per-machine w matrix) so callers
    # taking [-1] get FedAc's actual output, as in the non-HVP path.
    return W_avg


def newton(F, X, Y, steps, mu, t, verbose=True):
    """Run exact Newton's method from the origin with a fixed step size.

    Each step solves the Newton system with the Moore-Penrose
    pseudo-inverse, so a singular Hessian does not raise:
    ``w <- w - t * pinv(H(w)) @ g(w)``.

    Args:
        F: objective oracle. ``order=4`` returns the Hessian, ``order=6``
            the gradient, and ``order=5`` the loss.
        X, Y: data matrices; ``X`` has shape ``(n, d)``.
        steps: number of Newton iterations.
        mu: L2 regularization constant forwarded to the oracle.
        t: step size (``t = 1`` is the undamped Newton step).
        verbose: print the loss before the first step and after every step
            (the original behavior).

    Returns:
        The final iterate ``w`` as a ``(d, 1)`` array. Previously the
        iterate was discarded and ``None`` was returned.
    """
    _, d = X.shape
    w = np.zeros((d, 1))
    loss = F(X, Y, w, mu, order=5)
    if verbose:
        print(f"Loss is {loss}")
    for _ in range(steps):
        hess = F(X, Y, w, mu, order=4)
        grad = F(X, Y, w, mu, order=6)
        w = w - t * np.linalg.pinv(hess).dot(grad)
        loss = F(X, Y, w, mu, order=5)
        if verbose:
            print(f"Loss is {loss}")

    return w
